paddle implements torch.repeat_interleave/K.repeat_elements using paddle.reshape & paddle.tile

Reason is the light and the light of life.

Jerry Su Nov 16, 2021 2 mins

https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html#torch.Tensor.repeat

https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html

If the repeats is tensor([n1, n2, n3, …]), then the output will be tensor([0, 0, …, 1, 1, …, 2, 2, …, …]) where 0 appears n1 times, 1 appears n2 times, 2 appears n3 times, etc.

torch.repeat

torch.repeat_interleave

import torch
pos_emb = torch.arange(0, 48).reshape((2, 4, 6))   # [batch_size, seq_len, inner_dim]
print(pos_emb.shape)
print(pos_emb)
torch.Size([2, 4, 6])
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23]],

        [[24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41],
         [42, 43, 44, 45, 46, 47]]])
sin_pos = pos_emb[..., None, ::2]     # [batch_size, seq_len, 1, inner_dim // 2]
print(sin_pos.shape)
print(sin_pos)                        # 取奇数列
torch.Size([2, 4, 1, 3])
tensor([[[[ 0,  2,  4]],

         [[ 6,  8, 10]],

         [[12, 14, 16]],

         [[18, 20, 22]]],


        [[[24, 26, 28]],

         [[30, 32, 34]],

         [[36, 38, 40]],

         [[42, 44, 46]]]])
sin_pos = sin_pos.repeat_interleave(2, dim=-1)
print(sin_pos.shape)
print(sin_pos)                        # 最后一维复制
torch.Size([2, 4, 1, 6])
tensor([[[[ 0,  0,  2,  2,  4,  4]],

         [[ 6,  6,  8,  8, 10, 10]],

         [[12, 12, 14, 14, 16, 16]],

         [[18, 18, 20, 20, 22, 22]]],


        [[[24, 24, 26, 26, 28, 28]],

         [[30, 30, 32, 32, 34, 34]],

         [[36, 36, 38, 38, 40, 40]],

         [[42, 42, 44, 44, 46, 46]]]])

paddle

通过paddle.tile & paddle.reshape实现torch.repeat_interleave算子

https://github.com/PaddlePaddle/Paddle/issues/37227

import paddle
pos_emb_paddle = paddle.arange(0, 48).reshape((2, 4, 6))   # [batch_size, seq_len, inner_dim]
print(pos_emb_paddle.shape)
print(pos_emb_paddle)
[2, 4, 6]
Tensor(shape=[2, 4, 6], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[[0 , 1 , 2 , 3 , 4 , 5 ],
         [6 , 7 , 8 , 9 , 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23]],

        [[24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41],
         [42, 43, 44, 45, 46, 47]]])

pos_emb_paddle = pos_emb_paddle[..., None, ::2].reshape((2, 4, 3, 1)).tile((1, 1, 1, 2)).reshape((2, 4, 1, 6))

pos_emb_paddle = pos_emb_paddle[..., None, ::2].reshape((2,4,3,1)).tile((1, 1, 1, 2)).     # 转换低纬
pos_emb_paddle
Tensor(shape=[2, 4, 3, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[[[0 ],
          [2 ],
          [4 ]],

         [[6 ],
          [8 ],
          [10]],

         [[12],
          [14],
          [16]],

         [[18],
          [20],
          [22]]],


        [[[24],
          [26],
          [28]],

         [[30],
          [32],
          [34]],

         [[36],
          [38],
          [40]],

         [[42],
          [44],
          [46]]]])
pos_emb_paddle = pos_emb_paddle.tile((1, 1, 1, 2))  # 
pos_emb_paddle
Tensor(shape=[2, 4, 3, 2], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[[[0 , 0 ],
          [2 , 2 ],
          [4 , 4 ]],

         [[6 , 6 ],
          [8 , 8 ],
          [10, 10]],

         [[12, 12],
          [14, 14],
          [16, 16]],

         [[18, 18],
          [20, 20],
          [22, 22]]],


        [[[24, 24],
          [26, 26],
          [28, 28]],

         [[30, 30],
          [32, 32],
          [34, 34]],

         [[36, 36],
          [38, 38],
          [40, 40]],

         [[42, 42],
          [44, 44],
          [46, 46]]]])
pos_emb_paddle.reshape((2,4,1,6))
Tensor(shape=[2, 4, 1, 6], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[[[0 , 0 , 2 , 2 , 4 , 4 ]],

         [[6 , 6 , 8 , 8 , 10, 10]],

         [[12, 12, 14, 14, 16, 16]],

         [[18, 18, 20, 20, 22, 22]]],


        [[[24, 24, 26, 26, 28, 28]],

         [[30, 30, 32, 32, 34, 34]],

         [[36, 36, 38, 38, 40, 40]],

         [[42, 42, 44, 44, 46, 46]]]])

Read more:

Related posts: